"""
Phase 1: Data Assembly — Automation & Capital Intensity Variables
=================================================================
Construct automation proxies from existing full_panel.csv data, attempt
to download PWT 10.01 for labor share (labsh) and capital stock (rnna),
and merge into automation_panel.csv. Write summary statistics.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import io
import warnings
warnings.filterwarnings('ignore')

try:
    import requests
except ImportError:
    requests = None

# Directory layout, resolved from this script's own location: PROJECT_DIR is
# the parent of the directory holding this file, ROOT_DIR is one level above.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# "multilateral" directory under ROOT_DIR; main() reads full_panel.csv from
# beneath it.
MULTILATERAL_DIR = ROOT_DIR / "multilateral"

# Output locations under PROJECT_DIR. They are created at import time (parents
# included) so the functions below can write without checking first.
OUT_DATA = PROJECT_DIR / "data" / "processed"
OUT_RAW = PROJECT_DIR / "data" / "raw"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_DATA.mkdir(parents=True, exist_ok=True)
OUT_RAW.mkdir(parents=True, exist_ok=True)
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Three-letter country codes flagged as OECD. construct_proxies() matches them
# against the panel's `iso3` column to build the 0/1 `oecd` indicator.
OECD = [
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
]

# URL requested by download_pwt() and the local file it caches the response in.
# NOTE(review): the cache carries a .csv suffix, but parse_pwt() tries the
# Stata reader first -- the actual payload format is not established here.
PWT_URL = "https://dataverse.nl/api/access/datafile/354098"
PWT_CACHE = OUT_RAW / "pwt1001.csv"


# ── PWT Download ─────────────────────────────────────────────────────

def download_pwt(force=False):
    """Download the PWT 10.01 data file and cache it on disk.

    The payload is first written to a sibling ``.part`` file and only renamed
    onto ``PWT_CACHE`` once it has been written completely. An interrupted or
    failed write therefore never leaves a truncated cache behind (which later
    runs would silently reuse), and a failed forced refresh never destroys an
    existing good cache.

    Args:
        force: When True, download again even if a cached copy exists.

    Returns:
        ``PWT_CACHE`` when a cached or freshly downloaded file is available;
        ``None`` when ``requests`` is not installed, the server does not
        answer with HTTP 200, the response body is empty, or any error occurs
        while downloading or writing.
    """
    if PWT_CACHE.exists() and not force:
        print(f"  Using cached PWT: {PWT_CACHE}")
        return PWT_CACHE

    if requests is None:
        print("  WARNING: requests not available, skipping PWT download")
        return None

    print("  Downloading PWT 10.01...")
    staging_path = PWT_CACHE.with_name(PWT_CACHE.name + ".part")
    try:
        resp = requests.get(PWT_URL, timeout=120)
        if resp.status_code != 200:
            print(f"  WARNING: PWT download failed (status {resp.status_code})")
            return None
        if not resp.content:
            # A 200 with no body would otherwise be cached as a valid file.
            print("  WARNING: PWT download returned an empty file")
            return None
        staging_path.write_bytes(resp.content)
        # Rename within the same directory so the cache appears all at once.
        staging_path.replace(PWT_CACHE)
        print(f"  Saved: {PWT_CACHE} ({len(resp.content)/1024:.0f} KB)")
        return PWT_CACHE
    except Exception as e:
        # Deliberately best-effort: the pipeline falls back to proxy variables.
        print(f"  WARNING: PWT download error: {e}")
        try:
            staging_path.unlink()
        except OSError:
            # Nothing was staged (or it cannot be removed); the real cache
            # path was never touched, so there is nothing more to undo.
            pass
        return None


def parse_pwt(csv_path):
    """Parse a PWT file and derive the automation-related variables.

    The file is tried as Stata ``.dta`` first and, if that fails, as CSV
    (UTF-8 with optional BOM, then Latin-1). Column names are matched
    case-insensitively after stripping whitespace.

    Args:
        csv_path: Path to the PWT file (CSV or Stata, despite the name).

    Returns:
        A DataFrame with ``iso3`` and ``year`` plus whichever of ``labsh``,
        ``rnna`` and ``emp`` the file provides, and the derived columns
        ``capital_per_worker`` (needs ``rnna`` and ``emp``) and
        ``automation_proxy`` (needs ``labsh``). Returns ``None`` when the
        country-code or year column is missing.
    """
    print(f"  Parsing PWT: {csv_path}")
    # Try reading as Stata .dta first (PWT 10.01 may be Stata format)
    try:
        df = pd.read_stata(csv_path)
        print("  Read as Stata .dta format")
    except Exception:
        try:
            df = pd.read_csv(csv_path, encoding='utf-8-sig')
        except UnicodeDecodeError:
            df = pd.read_csv(csv_path, encoding='latin-1')

    cols = df.columns.tolist()
    print(f"  PWT columns (first 20): {cols[:20]}")

    # Standardize column names: only `countrycode` is actually renamed, the
    # rest are normalised to their lowercase, stripped form.
    standard_names = {
        'countrycode': 'iso3',
        'year': 'year',
        'labsh': 'labsh',
        'rnna': 'rnna',
        'emp': 'emp',
    }
    col_map = {}
    for c in cols:
        key = c.lower().strip()
        if key in standard_names:
            col_map[c] = standard_names[key]
    df = df.rename(columns=col_map)

    keep = [c for c in ['iso3', 'year', 'labsh', 'rnna', 'emp'] if c in df.columns]
    if 'iso3' not in keep or 'year' not in keep:
        print(f"  WARNING: PWT missing iso3/year columns. Available: {keep}")
        return None

    df = df[keep].copy()
    df['year'] = pd.to_numeric(df['year'], errors='coerce')
    for c in ['labsh', 'rnna', 'emp']:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors='coerce')

    # Compute capital per worker if possible
    if 'rnna' in df.columns and 'emp' in df.columns:
        # NOTE(review): the 1e6 factor assumes `emp` is recorded in millions of
        # persons -- confirm against the PWT variable definitions, together
        # with the unit of `rnna`, to pin down the unit of this ratio.
        ratio = df['rnna'] / (df['emp'] * 1e6)
        # emp == 0 produces +/-inf, and +inf would pass a plain `<= 0` filter;
        # keep only finite, strictly positive values and blank out the rest.
        usable = np.isfinite(ratio) & (ratio > 0)
        df['capital_per_worker'] = ratio.where(usable)
        print(f"  Computed capital_per_worker: "
              f"{df['capital_per_worker'].dropna().shape[0]} obs")

    # Compute automation proxy = 1 - labsh (capital share)
    if 'labsh' in df.columns:
        df['automation_proxy'] = 1 - df['labsh']
        print(f"  Computed automation_proxy (1 - labsh): "
              f"{df['automation_proxy'].dropna().shape[0]} obs")

    print(f"  PWT parsed: {len(df)} obs, {df['iso3'].nunique()} countries")
    return df


# ── Proxy Construction ───────────────────────────────────────────────

def construct_proxies(panel):
    """Add automation proxy columns built from variables already in the panel.

    Columns added (the frame is modified in place and also returned):
      * ``capital_intensity``       copy of ``gross_investment_gdp``
      * ``labor_productivity``      copy of ``gdp_pc_ppp``
      * ``log_labor_productivity``  natural log of labor productivity after
                                    flooring it at 100
      * ``gvc_proxy``               copy of ``trade_openness``
      * ``oecd``                    1 if ``iso3`` is in ``OECD``, else 0

    A proxy whose source column is absent is filled with NaN and a warning is
    printed.
    """
    print("\n--- Constructing Automation Proxies ---")

    def _copy_or_blank(target, source, label):
        # Mirror `source` into `target`; fall back to an all-NaN column.
        if source not in panel.columns:
            print(f"  WARNING: {source} not available")
            panel[target] = np.nan
            return
        panel[target] = panel[source]
        filled = panel[target].dropna().shape[0]
        print(f"  {label}: {filled} obs")

    # 1. Capital intensity (investment/GDP ratio)
    _copy_or_blank('capital_intensity', 'gross_investment_gdp',
                   'capital_intensity (investment/GDP)')

    # 2. Labor productivity proxy (GDP per capita, PPP)
    _copy_or_blank('labor_productivity', 'gdp_pc_ppp',
                   'labor_productivity (GDP per capita PPP)')

    # 3. Log labor productivity; values below 100 are floored before the log,
    #    and rows without productivity stay missing.
    productivity = panel['labor_productivity']
    logged = np.log(productivity.clip(lower=100))
    panel['log_labor_productivity'] = logged
    panel.loc[productivity.isna(), 'log_labor_productivity'] = np.nan

    # 4. GVC proxy (trade/GDP)
    _copy_or_blank('gvc_proxy', 'trade_openness', 'gvc_proxy (trade/GDP)')

    # 5. OECD indicator
    is_member = panel['iso3'].isin(OECD)
    panel['oecd'] = is_member.astype(int)
    member_count = panel.loc[panel['oecd'] == 1, 'iso3'].nunique()
    print(f"  OECD countries in panel: {member_count}")

    return panel


# ── Summary Statistics ───────────────────────────────────────────────

def write_summary_statistics(panel):
    """Write a Markdown summary-statistics table for the automation variables.

    For every variable in a fixed list that exists in ``panel`` and has at
    least one non-missing value, the table reports N, mean, standard
    deviation, minimum and maximum (four decimals). A footer gives the number
    of countries and the year span, followed by a note on whether any
    PWT-derived variable is populated. The table is saved as
    ``summary_statistics.md`` in ``OUT_TABLES``.

    Args:
        panel: DataFrame containing at least ``iso3`` and ``year``.
    """
    auto_vars = [
        'capital_intensity', 'labor_productivity', 'log_labor_productivity',
        'gvc_proxy', 'labsh', 'automation_proxy', 'capital_per_worker',
        'Z_1', 'Z_2', 'Z_3', 'ca_gdp', 'old_dep', 'youth_dep',
        'working_age_share', 'kaopen',
    ]

    lines = ["# Summary Statistics: Automation & Capital Intensity Panel\n"]
    lines.append("| Variable | N | Mean | Std | Min | Max |")
    lines.append("|:---|---:|---:|---:|---:|---:|")
    for var in auto_vars:
        if var not in panel.columns:
            continue
        s = panel[var].dropna()
        if len(s) == 0:
            # Entirely missing variables are left out of the table.
            continue
        lines.append(f"| {var} | {len(s)} | {s.mean():.4f} "
                     f"| {s.std():.4f} | {s.min():.4f} | {s.max():.4f} |")

    lines.append(f"\n*Panel: {panel['iso3'].nunique()} countries, "
                 f"{panel['year'].min()}--{panel['year'].max()}*")

    # PWT availability note
    pwt_vars = [v for v in ['labsh', 'automation_proxy', 'capital_per_worker']
                if v in panel.columns and panel[v].dropna().shape[0] > 0]
    if pwt_vars:
        # Join outside the f-string: reusing the enclosing quote character
        # inside a replacement field is a SyntaxError before Python 3.12.
        available = ", ".join(pwt_vars)
        lines.append(f"\n*PWT 10.01 variables available: {available}*")
    else:
        lines.append("\n*PWT 10.01 download failed; using proxy variables only.*")

    path = OUT_TABLES / "summary_statistics.md"
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── Main ─────────────────────────────────────────────────────────────

def main():
    """Run Phase 1 end to end.

    Loads the base panel (years up to and including 2024), tries to obtain
    and parse PWT and left-merges it on ``iso3``/``year`` when available,
    constructs the proxy variables, prints per-variable coverage, writes the
    summary table and saves ``automation_panel.csv`` to ``OUT_DATA``.
    """
    rule = "=" * 70
    print(rule)
    print("PHASE 1: DATA ASSEMBLY — AUTOMATION & CAPITAL INTENSITY")
    print(rule)

    # Load base panel
    fp_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    print(f"\n  Loading: {fp_path}")
    base = pd.read_csv(fp_path)
    panel = base[base['year'] <= 2024].copy()
    print(f"  Panel: {len(panel)} obs, {panel['iso3'].nunique()} countries, "
          f"{panel['year'].min()}--{panel['year'].max()}")

    # Attempt PWT download
    print("\n--- PWT 10.01 Download ---")
    pwt_path = download_pwt()

    # Parsing is best-effort: any failure leaves pwt_df as None.
    pwt_df = None
    pwt_on_disk = pwt_path is not None and pwt_path.exists()
    if pwt_on_disk:
        try:
            pwt_df = parse_pwt(pwt_path)
        except Exception as e:
            print(f"  ERROR parsing PWT: {e}")
            pwt_df = None

    # Merge PWT if available
    if pwt_df is None:
        print("\n  PWT data not available — using proxy variables only")
    else:
        rows_before = len(panel)
        key_cols = ['iso3', 'year']
        value_cols = [c for c in pwt_df.columns if c not in key_cols]
        # Drop any PWT columns already in panel to avoid conflicts
        clashing = [c for c in value_cols if c in panel.columns]
        if clashing:
            panel = panel.drop(columns=clashing)
        panel = panel.merge(pwt_df[key_cols + value_cols],
                            on=key_cols, how='left')
        print(f"\n  PWT merge: {rows_before} -> {len(panel)} obs")
        for col in value_cols:
            present = panel[col].dropna().shape[0]
            print(f"    {col}: {present} non-missing")

    # Construct proxies
    panel = construct_proxies(panel)

    # Report coverage
    print("\n--- Variable Coverage ---")
    for name in ('capital_intensity', 'labor_productivity', 'gvc_proxy',
                 'labsh', 'automation_proxy', 'capital_per_worker'):
        if name not in panel.columns:
            continue
        observed = panel[name].notna()
        n_obs = panel[name].dropna().shape[0]
        n_countries = panel.loc[observed, 'iso3'].nunique()
        print(f"  {name:25s}: {n_obs:5d} obs, {n_countries:3d} countries")

    # Summary statistics
    write_summary_statistics(panel)

    # Save
    out_path = OUT_DATA / "automation_panel.csv"
    panel.to_csv(out_path, index=False)
    print(f"\n  Saved: {out_path}")
    print(f"  Shape: {panel.shape}")

    print("\n" + rule)
    print("PHASE 1 COMPLETE")
    print(rule)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
